import torch
import config

from llama_cpp import Llama
from typing import Optional, List, Mapping, Any
from langchain.llms.base import LLM

class model_llama2_local(LLM):
    """Custom LangChain LLM wrapper that runs a local Llama 2 GGML model via llama.cpp."""

    # Config entry for the model; populated on first call. Defaults to None so
    # pydantic does not treat it as a required constructor argument.
    model_llama: Any = None

    def _call(self, prompt: str, stop: Optional[List[str]] = None) -> str:
        if torch.cuda.is_available():
            print(f"CUDA device name: {torch.cuda.get_device_name(0)}")
        else:
            print("CUDA is not available. Will use CPU...")
        # Find the enabled llama2-7b-ggml entry in the application config.
        for llm_config in config.get_config().models.llms:
            if llm_config.name == "llama2-7b-ggml" and llm_config.enable == 1:
                self.model_llama = llm_config
                break
        if self.model_llama is None:
            raise Exception("llama2-7b-ggml is not enabled")
        # Note: the model is reloaded from disk on every call; cache the Llama
        # instance if call latency matters.
        llm = Llama(model_path=self.model_llama.path, n_threads=4)
        full_prompt = f"Q: {prompt} A: "
        response = llm(full_prompt, max_tokens=256, stop=stop or [])
        output = response['choices'][0]['text']
        # llama-cpp-python returns only the completion by default (echo=False);
        # strip the prompt defensively in case it is echoed back.
        if output.startswith(full_prompt):
            output = output[len(full_prompt):]
        return output.replace('A: ', '').strip()
    
    @property
    def _identifying_params(self) -> Mapping[str, Any]:
        return {"name_of_model": "llama2"}

    @property
    def _llm_type(self) -> str:
        return "custom"